{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(160, 15, 2560)\n",
      "(160,)\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Format dataset: read the files for the selected subjects and parse the data to extract:\n",
    "- samplingRate\n",
    "- trialLength (samples per trial window)\n",
    "- X, an M x N x K float32 array, i.e. trial x channel x sample\n",
    "     (160 x 15 x 2560 for a single subject)\n",
    "- y, a length-M int64 vector of labels in {0,1}\n",
    "\n",
    "ref:\n",
    "Dataset description: https://lampx.tugraz.at/~bci/database/002-2014/description.pdf\n",
    "\"\"\"\n",
    "\n",
    "import scipy.io as sio\n",
    "import numpy as np\n",
    "\n",
    "# Subjects to load; the files on disk are S01..S14. Loading all of them at once\n",
    "# may not fit in RAM, so a single subject is selected here.\n",
    "SUBJECTS = [8]\n",
    "N_TRAINING_RUNS = 5    # runs read from each S??T.mat file\n",
    "N_EVALUATION_RUNS = 3  # runs read from each S??E.mat file\n",
    "TRIAL_SECONDS = 5      # length of the window cut after each trial onset\n",
    "\n",
    "trainingFileList = ['BBCIData/S{:02d}T.mat'.format(s) for s in SUBJECTS]\n",
    "validationFileList = ['BBCIData/S{:02d}E.mat'.format(s) for s in SUBJECTS]\n",
    "\n",
    "\n",
    "def extract_trials(mat, n_runs, trial_length):\n",
    "    \"\"\"Cut fixed-length trial windows out of the first `n_runs` runs of a loaded .mat file.\n",
    "\n",
    "    Returns (trials, labels): a list of float32 arrays of shape (chan, trial_length)\n",
    "    and a list with one label array per run (labels still as stored in the file).\n",
    "    Raises ValueError if a window runs past the end of its recording.\n",
    "    \"\"\"\n",
    "    trials = []\n",
    "    labels = []\n",
    "    for run in range(n_runs):\n",
    "        run_data = mat['data'][0][run][0][0]\n",
    "        labels.append(run_data[2][0])        # labels\n",
    "        timestamps = run_data[1][0]          # trial onsets\n",
    "        rawData = run_data[0].transpose()    # chan x data\n",
    "\n",
    "        # NOTE(review): onsets come from MATLAB and may be 1-based sample indices;\n",
    "        # confirm against the dataset description whether windows should start at start - 1.\n",
    "        for start in timestamps:\n",
    "            trial = rawData[:, start:start + trial_length]\n",
    "            if trial.shape[1] != trial_length:\n",
    "                raise ValueError('run {}: trial at sample {} is truncated ({} of {} samples)'.format(\n",
    "                    run, start, trial.shape[1], trial_length))\n",
    "            # astype copies the window, so the full recording can be freed once this\n",
    "            # file is done (a plain slice is a view and would keep it alive)\n",
    "            trials.append(trial.astype(np.float32))\n",
    "    return trials, labels\n",
    "\n",
    "\n",
    "def load_subject(training_file, validation_file):\n",
    "    \"\"\"Load one subject's training (T) and evaluation (E) files.\n",
    "\n",
    "    Returns (samplingRate, trialLength, trials, labels); trials/labels hold the\n",
    "    training runs first, then the evaluation runs (we do not discriminate at this point).\n",
    "    \"\"\"\n",
    "    d1T = sio.loadmat(training_file)\n",
    "    d1E = sio.loadmat(validation_file)\n",
    "\n",
    "    samplingRate = d1T['data'][0][0][0][0][3][0][0]\n",
    "    trialLength = TRIAL_SECONDS * samplingRate\n",
    "\n",
    "    trials, labels = extract_trials(d1T, N_TRAINING_RUNS, trialLength)\n",
    "    eval_trials, eval_labels = extract_trials(d1E, N_EVALUATION_RUNS, trialLength)\n",
    "    return samplingRate, trialLength, trials + eval_trials, labels + eval_labels\n",
    "\n",
    "\n",
    "# prepare data containers\n",
    "X = []\n",
    "y = []\n",
    "\n",
    "assert len(trainingFileList) == len(validationFileList), 'need one E file per T file'\n",
    "for trainingFile, validationFile in zip(trainingFileList, validationFileList):\n",
    "    samplingRate, trialLength, trials, labels = load_subject(trainingFile, validationFile)\n",
    "    X.extend(trials)\n",
    "    y.extend(labels)\n",
    "\n",
    "# arrange data into numpy arrays\n",
    "# torch expects float32 for samples and int64 for labels {0,1} (file labels are {1,2})\n",
    "X = np.array(X, dtype=np.float32)\n",
    "y = (np.concatenate(y) - 1).astype(np.int64)\n",
    "print(X.shape)\n",
    "print(y.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "from braindecode.datautil.signal_target import SignalAndTarget\n",
    "from braindecode.models.shallow_fbcsp import ShallowFBCSPNet\n",
    "from torch import nn\n",
    "from braindecode.torch_ext.util import set_random_seeds\n",
    "from torch import optim\n",
    "import torch\n",
    "\n",
    "SEED = 20170629\n",
    "TRAIN_FRACTION = 5/8  # 100 of the 160 trials go to training\n",
    "\n",
    "# Use the GPU when available (torch.cuda.is_available() checks for CUDA on this machine).\n",
    "cuda = torch.cuda.is_available()\n",
    "# Seed BEFORE shuffling: braindecode's set_random_seeds also seeds numpy's global RNG,\n",
    "# so this makes the train/test split below reproducible, not only the network init.\n",
    "set_random_seeds(seed=SEED, cuda=cuda)\n",
    "\n",
    "# Shuffle the trials. Trials from the T and E files are pooled here, so the test set\n",
    "# mixes both files rather than holding out the E session.\n",
    "# New names (instead of overwriting X, y) keep this cell safe to re-run.\n",
    "idx = np.random.permutation(X.shape[0])\n",
    "X_shuffled = X[idx]\n",
    "y_shuffled = y[idx]\n",
    "\n",
    "nb_train_trials = int(np.floor(TRAIN_FRACTION * X_shuffled.shape[0]))\n",
    "train_set = SignalAndTarget(X_shuffled[:nb_train_trials], y=y_shuffled[:nb_train_trials])\n",
    "test_set = SignalAndTarget(X_shuffled[nb_train_trials:], y=y_shuffled[nb_train_trials:])\n",
    "\n",
    "n_classes = 2\n",
    "in_chans = train_set.X.shape[1]\n",
    "# final_conv_length='auto' ensures we only get a single output in the time dimension\n",
    "model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,\n",
    "                        input_time_length=train_set.X.shape[2],\n",
    "                        final_conv_length='auto').create_network()\n",
    "if cuda:\n",
    "    model.cuda()\n",
    "\n",
    "optimizer = optim.Adam(model.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "from braindecode.torch_ext.util import np_to_var, var_to_np\n",
    "from braindecode.datautil.iterators import get_balanced_batches\n",
    "import torch.nn.functional as F\n",
    "from numpy.random import RandomState\n",
    "\n",
    "N_EPOCHS = 50\n",
    "BATCH_SIZE = 32\n",
    "# Fixed seed so the batch order is reproducible; use RandomState(None) for a fresh order.\n",
    "rng = RandomState((2017, 6, 30))\n",
    "\n",
    "for i_epoch in range(N_EPOCHS):\n",
    "    i_trials_in_batch = get_balanced_batches(len(train_set.X), rng, shuffle=True,\n",
    "                                             batch_size=BATCH_SIZE)\n",
    "    # Set model to training mode\n",
    "    model.train()\n",
    "    for i_trials in i_trials_in_batch:\n",
    "        # Have to add empty fourth dimension to X\n",
    "        net_in = np_to_var(train_set.X[i_trials][:, :, :, None])\n",
    "        net_target = np_to_var(train_set.y[i_trials])\n",
    "        if cuda:\n",
    "            net_in = net_in.cuda()\n",
    "            net_target = net_target.cuda()\n",
    "        # Remove gradients of last backward pass from all parameters\n",
    "        optimizer.zero_grad()\n",
    "        # Compute outputs of the network\n",
    "        outputs = model(net_in)\n",
    "        # Compute the loss\n",
    "        loss = F.nll_loss(outputs, net_target)\n",
    "        # Do the backpropagation\n",
    "        loss.backward()\n",
    "        # Update parameters with the optimizer\n",
    "        optimizer.step()\n",
    "\n",
    "    # Print some statistics each epoch\n",
    "    model.eval()\n",
    "    print(\"Epoch {:d}\".format(i_epoch))\n",
    "    for setname, dataset in (('Train', train_set), ('Test', test_set)):\n",
    "        i_trials_in_batch = get_balanced_batches(len(dataset.X), rng, batch_size=BATCH_SIZE, shuffle=False)\n",
    "        outputs = []\n",
    "        targets = []\n",
    "        for i_trials in i_trials_in_batch:\n",
    "            # Index the set being evaluated (dataset), not train_set; otherwise the\n",
    "            # 'Test' rows would score training trials against test labels.\n",
    "            net_in = np_to_var(dataset.X[i_trials][:, :, :, None])\n",
    "            if cuda:\n",
    "                net_in = net_in.cuda()\n",
    "            outputs.append(var_to_np(model(net_in)))\n",
    "            targets.append(dataset.y[i_trials])\n",
    "        # Targets are gathered with the same indices as the outputs, so the two stay\n",
    "        # aligned regardless of batch order.\n",
    "        outputs = np.concatenate(outputs)\n",
    "        targets = np.concatenate(targets)\n",
    "        loss = F.nll_loss(np_to_var(outputs), np_to_var(targets))\n",
    "        print(\"{:6s} Loss: {:.5f}\".format(\n",
    "            setname, float(var_to_np(loss))))\n",
    "        predicted_labels = np.argmax(outputs, axis=1)\n",
    "        accuracy = np.mean(targets == predicted_labels)\n",
    "        print(\"{:6s} Accuracy: {:.1f}%\".format(\n",
    "            setname, accuracy * 100))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Problem: RAM is not big enough to load every subject at once.\n",
    "# TODO (next session): manage batches through the hard drive instead of memory.\n",
    "# TODO: add analytics on training performance.\n",
    "\n",
    "# Rough results\n",
    "# CAVEAT: the evaluation loop used to produce these numbers read the 'Test' inputs from\n",
    "# train_set (train_set.X[i_trials]) instead of the test set, so Test figures produced with\n",
    "# that loop (e.g. Subject 8) are not held-out estimates; note how each Test loss tracks\n",
    "# its Train loss. Re-run with a corrected evaluation loop before drawing conclusions.\n",
    "\n",
    "# Subject 1:--------------------------------------------\n",
    "# Epoch 49\n",
    "# Train  Loss: 0.00253\n",
    "# Train  Accuracy: 100.0%\n",
    "# Test   Loss: 0.00272\n",
    "# Test   Accuracy: 60.0%\n",
    "\n",
    "# Subject 2:--------------------------------------------\n",
    "# Epoch 49\n",
    "# Train  Loss: 0.00132\n",
    "# Train  Accuracy: 100.0%\n",
    "# Test   Loss: 0.00145\n",
    "# Test   Accuracy: 45.0%\n",
    "\n",
    "# Subject 3:--------------------------------------------\n",
    "# Epoch 27\n",
    "# Train  Loss: 0.00212\n",
    "# Train  Accuracy: 100.0%\n",
    "# Test   Loss: 0.00209\n",
    "# Test   Accuracy: 43.3%\n",
    "\n",
    "# Subject 4:--------------------------------------------\n",
    "# Epoch 34\n",
    "# Train  Loss: 0.00524\n",
    "# Train  Accuracy: 100.0%\n",
    "# Test   Loss: 0.00559\n",
    "# Test   Accuracy: 46.7%\n",
    "\n",
    "# Subject 5:--------------------------------------------\n",
    "# Epoch 33\n",
    "# Train  Loss: 0.01777\n",
    "# Train  Accuracy: 100.0%\n",
    "# Test   Loss: 0.00994\n",
    "# Test   Accuracy: 55.0%\n",
    "\n",
    "# Subject 6:--------------------------------------------\n",
    "# Epoch 49\n",
    "# Train  Loss: 0.00556\n",
    "# Train  Accuracy: 100.0%\n",
    "# Test   Loss: 0.00560\n",
    "# Test   Accuracy: 56.7%\n",
    "\n",
    "# Subject 7:--------------------------------------------\n",
    "# Epoch 49\n",
    "# Train  Loss: 0.00129\n",
    "# Train  Accuracy: 100.0%\n",
    "# Test   Loss: 0.00143\n",
    "# Test   Accuracy: 51.7%\n",
    "\n",
    "# Subject 8:--------------------------------------------\n",
    "# Epoch 49\n",
    "# Train  Loss: 0.19644\n",
    "# Train  Accuracy: 95.0%\n",
    "# Test   Loss: 0.15150\n",
    "# Test   Accuracy: 50.0%"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
